package com.example.repository;

import com.example.model.ChatMessage;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

import java.time.LocalDateTime;
import java.util.List;

/**
 * Repository for {@link ChatMessage} entities.
 *
 * <p>Declares JPQL queries for reading conversation history, counting unread
 * messages, flagging messages as read, and purging old messages.
 *
 * <p>NOTE(review): the {@code @Modifying} methods execute bulk JPQL statements.
 * They are expected to be invoked inside a transaction supplied by the caller,
 * and entities already loaded in the persistence context may not reflect the
 * change afterwards - confirm callers account for this.
 */
@Repository
public interface ChatMessageRepository extends JpaRepository<ChatMessage, Long> {

    /**
     * Returns the full conversation between two users, in either direction,
     * ordered by creation time ascending (oldest first).
     *
     * @param userId1 ID of one participant
     * @param userId2 ID of the other participant
     * @return every message sent from one participant to the other
     */
    @Query("SELECT cm FROM ChatMessage cm WHERE " +
           "(cm.senderId = :userId1 AND cm.receiverId = :userId2) OR " +
           "(cm.senderId = :userId2 AND cm.receiverId = :userId1) " +
           "ORDER BY cm.createdAt ASC")
    List<ChatMessage> findChatHistory(@Param("userId1") String userId1, 
                                     @Param("userId2") String userId2);

    /**
     * Returns one page of the conversation between two users, in either
     * direction, ordered by creation time descending (newest first).
     *
     * @param userId1  ID of one participant
     * @param userId2  ID of the other participant
     * @param pageable page request (page number and size)
     * @return the requested page of messages
     */
    @Query("SELECT cm FROM ChatMessage cm WHERE " +
           "(cm.senderId = :userId1 AND cm.receiverId = :userId2) OR " +
           "(cm.senderId = :userId2 AND cm.receiverId = :userId1) " +
           "ORDER BY cm.createdAt DESC")
    Page<ChatMessage> findChatHistory(@Param("userId1") String userId1, 
                                     @Param("userId2") String userId2, 
                                     Pageable pageable);

    /**
     * Returns the distinct IDs of the counterpart in every message the given
     * user sent or received.
     *
     * <p>For each matching message the counterpart is the receiver when the
     * user is the sender, and the sender otherwise.
     *
     * @param userId ID of the user whose conversation partners are wanted
     * @return distinct counterpart user IDs
     */
    @Query("SELECT DISTINCT CASE " +
           "WHEN cm.senderId = :userId THEN cm.receiverId " +
           "ELSE cm.senderId END " +
           "FROM ChatMessage cm WHERE cm.senderId = :userId OR cm.receiverId = :userId")
    List<String> findChatPartners(@Param("userId") String userId);

    /**
     * Counts the messages addressed to the given user that are not yet read.
     *
     * @param userId ID of the receiving user
     * @return number of messages with this receiver whose read flag is false
     */
    @Query("SELECT COUNT(cm) FROM ChatMessage cm WHERE cm.receiverId = :userId AND cm.isRead = false")
    Long countUnreadMessages(@Param("userId") String userId);

    /**
     * Counts the unread messages sent by one user to another.
     *
     * <p>The count is one-directional: only messages from {@code senderId} to
     * {@code receiverId} are considered, not the reverse direction.
     *
     * @param senderId   ID of the user who sent the messages
     * @param receiverId ID of the user who received the messages
     * @return number of unread messages from sender to receiver
     */
    @Query("SELECT COUNT(cm) FROM ChatMessage cm WHERE cm.senderId = :senderId AND cm.receiverId = :receiverId AND cm.isRead = false")
    Long countUnreadMessagesBetween(@Param("senderId") String senderId, @Param("receiverId") String receiverId);

    /**
     * Flags a single message as read and records when it was read.
     *
     * @param messageId primary key of the message to update
     * @param readAt    timestamp to store as the read time
     */
    @Modifying
    @Query("UPDATE ChatMessage cm SET cm.isRead = true, cm.readAt = :readAt WHERE cm.id = :messageId")
    void markAsRead(@Param("messageId") Long messageId, @Param("readAt") LocalDateTime readAt);

    /**
     * Flags several messages as read in one statement and records the read time.
     *
     * <p>NOTE(review): an empty {@code messageIds} list produces an empty
     * {@code IN} clause, which some databases reject - confirm callers never
     * pass an empty list.
     *
     * @param messageIds primary keys of the messages to update
     * @param readAt     timestamp to store as the read time
     */
    @Modifying
    @Query("UPDATE ChatMessage cm SET cm.isRead = true, cm.readAt = :readAt WHERE cm.id IN :messageIds")
    void markAsReadBatch(@Param("messageIds") List<Long> messageIds, @Param("readAt") LocalDateTime readAt);

    /**
     * Flags every unread message sent by one user to another as read.
     *
     * <p>The update is one-directional: only messages from {@code senderId} to
     * {@code receiverId} are touched.
     *
     * @param senderId   ID of the user who sent the messages
     * @param receiverId ID of the user who received (and has now read) them
     * @param readAt     timestamp to store as the read time
     */
    @Modifying
    @Query("UPDATE ChatMessage cm SET cm.isRead = true, cm.readAt = :readAt WHERE " +
           "cm.senderId = :senderId AND cm.receiverId = :receiverId AND cm.isRead = false")
    void markAllAsReadBetween(@Param("senderId") String senderId, 
                             @Param("receiverId") String receiverId, 
                             @Param("readAt") LocalDateTime readAt);

    /**
     * Returns one page of the messages the given user sent or received,
     * ordered by creation time descending (newest first).
     *
     * @param userId   ID of the user
     * @param pageable page request (page number and size)
     * @return the requested page of messages
     */
    @Query("SELECT cm FROM ChatMessage cm WHERE " +
           "(cm.senderId = :userId OR cm.receiverId = :userId) " +
           "ORDER BY cm.createdAt DESC")
    Page<ChatMessage> findLatestMessages(@Param("userId") String userId, Pageable pageable);

    /**
     * Returns the messages exchanged between two users, in either direction,
     * whose creation time lies within the given range (bounds included),
     * ordered by creation time ascending.
     *
     * @param userId1   ID of one participant
     * @param userId2   ID of the other participant
     * @param startTime lower bound of the creation time
     * @param endTime   upper bound of the creation time
     * @return matching messages, oldest first
     */
    // The two direction predicates are wrapped in one group on purpose: AND binds
    // tighter than OR, so without the outer parentheses the time range would only
    // constrain the userId2 -> userId1 direction.
    @Query("SELECT cm FROM ChatMessage cm WHERE " +
           "((cm.senderId = :userId1 AND cm.receiverId = :userId2) OR " +
           "(cm.senderId = :userId2 AND cm.receiverId = :userId1)) " +
           "AND cm.createdAt BETWEEN :startTime AND :endTime " +
           "ORDER BY cm.createdAt ASC")
    List<ChatMessage> findMessagesBetweenTime(@Param("userId1") String userId1,
                                             @Param("userId2") String userId2,
                                             @Param("startTime") LocalDateTime startTime,
                                             @Param("endTime") LocalDateTime endTime);

    /**
     * Returns all messages of the given type, newest first.
     *
     * @param messageType message type to match
     * @return matching messages ordered by creation time descending
     */
    List<ChatMessage> findByMessageTypeOrderByCreatedAtDesc(ChatMessage.MessageType messageType);

    /**
     * Returns all messages sent by the given user, newest first.
     *
     * @param senderId ID of the sending user
     * @return matching messages ordered by creation time descending
     */
    List<ChatMessage> findBySenderIdOrderByCreatedAtDesc(String senderId);

    /**
     * Returns all messages received by the given user, newest first.
     *
     * @param receiverId ID of the receiving user
     * @return matching messages ordered by creation time descending
     */
    List<ChatMessage> findByReceiverIdOrderByCreatedAtDesc(String receiverId);

    /**
     * Deletes every message created strictly before the given time.
     *
     * @param beforeTime cutoff; messages with an earlier creation time are removed
     */
    @Modifying
    @Query("DELETE FROM ChatMessage cm WHERE cm.createdAt < :beforeTime")
    void deleteMessagesBeforeTime(@Param("beforeTime") LocalDateTime beforeTime);
}
